# -*- coding: utf-8 -*-
#
# Copyright 2019 Ricequant, Inc
#
# * Commercial Usage: please contact public@ricequant.com
# * Non-Commercial Usage:
#     Licensed under the Apache License, Version 2.0 (the "License");
#     you may not use this file except in compliance with the License.
#     You may obtain a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#     Unless required by applicable law or agreed to in writing, software
#     distributed under the License is distributed on an "AS IS" BASIS,
#     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#     See the License for the specific language governing permissions and
#     limitations under the License.


from rqalpha.interface import AbstractMod
from rqalpha.const import DEFAULT_ACCOUNT_TYPE
from rqalpha import export_as_api

from .account_model import StockAccount, FutureAccount
from .position_model import StockPositionProxy, FuturePositionProxy
from .api import api_future, api_stock, api_base


class AccountMod(AbstractMod):
    """Mod that wires the stock/future account and position models into the environment.

    On start-up it copies selected mod settings onto the model classes, registers
    the account and position model classes per account type, and exports the
    account-related API functions (always the base API, plus the future/stock
    APIs only when those account types are configured).
    """

    def start_up(self, env, mod_config):
        """Configure the model classes and register them (and the API) with ``env``.

        :param env: environment object; must provide ``set_account_model``,
            ``set_position_model`` and ``config.base.accounts``.
        :param mod_config: this mod's configuration; ``stock_t1``,
            ``dividend_reinvestment`` and ``future_forced_liquidation`` are read.
        """
        stock_type = DEFAULT_ACCOUNT_TYPE.STOCK.name
        future_type = DEFAULT_ACCOUNT_TYPE.FUTURE.name

        # Settings are stored as class-level attributes, so they apply to every
        # instance of the respective model. (Presumably stock_t1 toggles T+1
        # selling restrictions — confirm against StockPositionProxy.)
        StockPositionProxy.stock_t1 = mod_config.stock_t1
        StockAccount.dividend_reinvestment = mod_config.dividend_reinvestment
        FutureAccount.forced_liquidation = mod_config.future_forced_liquidation

        # Register account models first, then position models; stock before
        # future in each case.
        account_models = (
            (stock_type, StockAccount),
            (future_type, FutureAccount),
        )
        for account_type, model_cls in account_models:
            env.set_account_model(account_type, model_cls)

        position_models = (
            (stock_type, StockPositionProxy),
            (future_type, FuturePositionProxy),
        )
        for account_type, model_cls in position_models:
            env.set_position_model(account_type, model_cls)

        def _export_public_api(api_module):
            # Export every name the module lists in its ``__all__``.
            for public_name in api_module.__all__:
                export_as_api(getattr(api_module, public_name))

        # The base API is exported unconditionally. The account-type-specific
        # APIs are exported only for configured account types, future before
        # stock.
        _export_public_api(api_base)

        configured_accounts = env.config.base.accounts
        if future_type in configured_accounts:
            _export_public_api(api_future)
        if stock_type in configured_accounts:
            _export_public_api(api_stock)

    def tear_down(self, code, exception=None):
        """Perform no cleanup; nothing set up in ``start_up`` is undone here."""
        return None
